
[PerfXLab] optimize fill performance#2216

Open
bin913 wants to merge 5 commits into flagos-ai:master from bin913:fill

Conversation

Contributor

@bin913 bin913 commented Apr 2, 2026

PR Category

[Operator]

Type of Change

[Performance Optimization]

Description

Optimize the performance of fill.fill_scalar_ for the fill operator.

Issue

Progress

  • Change is properly reviewed (1 reviewer required, 2 recommended).
  • Change responds to an issue.
  • Change is fully covered by a UT.

Performance

Operator: fill_scalar_  Performance Test (dtype=torch.float16, mode=kernel, level=core)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.653856            0.653792               1.000          [torch.Size([1073741824]), 3.14159]
SUCCESS               0.005008            0.004992               1.003          [torch.Size([64, 64]), 3.14159]
SUCCESS               0.015104            0.015296               0.987          [torch.Size([4096, 4096]), 3.14159]
SUCCESS               0.015328            0.015328               1.000          [torch.Size([64, 512, 512]), 3.14159]
SUCCESS               0.654080            0.654224               1.000          [torch.Size([1024, 1024, 1024]), 3.14159]


Operator: fill_scalar_  Performance Test (dtype=torch.float32, mode=kernel, level=core)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               1.305952            1.303360               1.002          [torch.Size([1073741824]), 3.14159]
SUCCESS               0.005120            0.005024               1.019          [torch.Size([64, 64]), 3.14159]
SUCCESS               0.025280            0.025376               0.996          [torch.Size([4096, 4096]), 3.14159]
SUCCESS               0.025120            0.025216               0.996          [torch.Size([64, 512, 512]), 3.14159]
SUCCESS               1.306016            1.303264               1.002          [torch.Size([1024, 1024, 1024]), 3.14159]


Operator: fill_scalar_  Performance Test (dtype=torch.bfloat16, mode=kernel, level=core)
Status       Torch Latency (ms)    Gems Latency (ms)         Gems Speedup          Size Detail
-----------------------------------------------------------------------------------------------
SUCCESS               0.654176            0.654080               1.000          [torch.Size([1073741824]), 3.14159]
SUCCESS               0.004992            0.005152               0.969          [torch.Size([64, 64]), 3.14159]
SUCCESS               0.015232            0.015296               0.996          [torch.Size([4096, 4096]), 3.14159]
SUCCESS               0.015488            0.015328               1.010          [torch.Size([64, 512, 512]), 3.14159]
SUCCESS               0.653888            0.653904               1.000          [torch.Size([1024, 1024, 1024]), 3.14159]

# tensor constructor with given value
("fill_", torch.fill_, fill_input_fn),
("fill_scalar_", torch.ops.aten.fill_.Scalar, fill_input_fn),
# ("fill_scalar_", flag_gems.ops.fill.fill_scalar_, fill_input_fn),
Collaborator


Why is the FlagGems benchmark for fill_scalar_ commented out?

Contributor Author


Sorry, line 194 should be removed. I will remove it.



def fill_scalar(input, value):
logger.debug("GEMS FILL_SCALAR HOPPER")
Contributor


GEMS_HOPPER FILL_SCALAR



def fill_scalar_out(input, value, *, out=None):
logger.debug("GEMS FILL_SCALAR_OUT HOPPER")
Contributor


GEMS_HOPPER FILL_SCALAR_OUT

def fill_tensor(input, value):
if not value.is_cuda:
return fill_scalar(input, value.item())
logger.debug("GEMS FILL_TENSOR HOPPER")
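
The CPU-tensor fallback in the snippet above can be sketched in isolation. This is a hypothetical illustration with stub classes, not the FlagGems implementation: when the fill value is a host-resident 0-d tensor, value.item() extracts a Python scalar so dispatch can take the scalar path rather than reading the value from the device.

```python
class FakeTensor:
    """Minimal stand-in for a tensor holding a single value (illustrative stub)."""

    def __init__(self, value, is_cuda):
        self.value = value
        self.is_cuda = is_cuda

    def item(self):
        # Mirrors torch.Tensor.item(): return the contents as a Python scalar.
        return self.value


def fill_scalar(input, value):
    # Scalar path: every element of `input` becomes the Python scalar `value`.
    input.value = value
    return input


def fill_tensor(input, value):
    # Fallback from the snippet above: a host-resident value tensor is
    # converted to a scalar, so the scalar kernel path is taken instead.
    if not value.is_cuda:
        return fill_scalar(input, value.item())
    input.value = value.value  # device path (stubbed out here)
    return input


out = fill_tensor(FakeTensor(0.0, is_cuda=True), FakeTensor(3.14159, is_cuda=False))
print(out.value)  # 3.14159
```

The same dispatch shape appears again below in fill_tensor_, the in-place variant.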
Contributor


Suggested change
logger.debug("GEMS FILL_TENSOR HOPPER")
logger.debug("GEMS_HOPPER FILL_TENSOR")



def fill_tensor_out(input, value, *, out=None):
logger.debug("GEMS FILL_TENSOR_OUT HOPPER")
Contributor


Suggested change
logger.debug("GEMS FILL_TENSOR_OUT HOPPER")
logger.debug("GEMS_HOPPER FILL_TENSOR_OUT")

def fill_tensor_(self, value):
if not value.is_cuda:
return fill_scalar_(self, value.item())
logger.debug("GEMS FILL_TENSOR_ HOPPER")
Contributor


Suggested change
logger.debug("GEMS FILL_TENSOR_ HOPPER")
logger.debug("GEMS_HOPPER FILL_TENSOR_")



def fill_scalar_(self, value):
logger.debug("GEMS FILL_SCALAR_ HOPPER")
Contributor


Suggested change
logger.debug("GEMS FILL_SCALAR_ HOPPER")
logger.debug("GEMS_HOPPER FILL_SCALAR_")



4 participants